import copy
import os
import time

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
#from utils import train
import torchvision.models as models
import matplotlib.pyplot as plt
import cvxpy as cvx
import scipy.io as scio
# Wall-clock start; the end of the script subtracts this to report run time.
time_start = time.time()

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

class Net_f(nn.Module):
    """Image feature network: GoogLeNet trunk followed by a two-layer head.

    The trunk is the first 18 child modules of torchvision's pretrained
    GoogLeNet wrapped in ``nn.Sequential``.
    NOTE(review): presumably this is everything up to, but excluding, the
    final classifier layer -- confirm against the installed torchvision
    version, since the child list is not visible from this file.
    The head maps 1024 -> 32 -> 10.
    """

    def __init__(self):
        super().__init__()
        backbone = models.googlenet(pretrained=True, progress=True)
        trunk_layers = list(backbone.children())[0:18]
        self.feature = torch.nn.Sequential(*trunk_layers)
        self.fc1 = nn.Linear(1024, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        """Return the 10-dimensional feature for each input in ``x``."""
        trunk_out = self.feature(x)
        # Collapse the trunk output into rows of 1024 values for the head.
        flat = trunk_out.view(-1, 1024)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

 

class Net_g(nn.Module):
    """Label embedding network: a single linear layer.

    Args:
        num_class: size of the input vector (the script feeds one-hot labels).
        dim: size of the output embedding.
    """

    def __init__(self, num_class=2, dim=10):
        super().__init__()
        self.fc = nn.Linear(num_class, dim)

    def forward(self, x):
        """Return the linear embedding of ``x``."""
        return self.fc(x)

def corr(f, g):
    """Return the batch mean of the row-wise inner products of ``f`` and ``g``.

    Both arguments are multiplied element-wise, summed over dimension 1 and
    averaged over the remaining entries; the result is a scalar tensor.
    """
    row_dots = (f * g).sum(1)
    return row_dots.mean()
    
def cov_trace(f, g):
    """Return ``trace(C_f @ C_g)`` with ``C_x = x^T x / (n_x - 1)``.

    ``n_x`` is the number of rows of the respective input. The inputs are
    used as given: no mean is subtracted here, so ``C_x`` is an (unbiased-
    scaled) second-moment matrix rather than a centred covariance.
    """
    n_f = f.size()[0]
    n_g = g.size()[0]
    second_moment_f = torch.t(f).mm(f) / (n_f - 1.)
    second_moment_g = torch.t(g).mm(g) / (n_g - 1.)
    return second_moment_f.mm(second_moment_g).trace()

# Image pipeline handed to every CIFAR10 dataset below: random 224x224
# resized crop, random horizontal flip, tensor conversion, then per-channel
# normalisation with the mean/std values listed here.
_normalize = transforms.Normalize(
    mean=[0.485,0.456,0.406],
    std=[0.229,0.224,0.225],
)
_pipeline = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
]
transform = transforms.Compose(_pipeline)



# Five independent copies of the CIFAR-10 training split. Each copy is
# filtered in place to a different class pair by the preprocessing section,
# so they must be separate dataset objects.
groupa = CIFAR10('./data', train=True, transform=transform, download=True)
groupb = CIFAR10('./data', train=True, transform=transform, download=True)
groupc = CIFAR10('./data', train=True, transform=transform, download=True)
groupd = CIFAR10('./data', train=True, transform=transform, download=True)
groupe = CIFAR10('./data', train=True, transform=transform, download=True)

# The test split gets its own deterministic pipeline. Re-using the training
# transform would apply a fresh random crop/flip on every evaluation pass,
# making the reported accuracy noisy and letting checkpoint selection pick up
# lucky augmentations instead of better weights.
test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
    ])
test_set = CIFAR10('./data', train=False, transform=test_transform, download=True)

#-------------preprocess----------------------------------------------------
def _select_class_pair(dataset, first_class, second_class,
                       first_label=0, second_label=1, limit=2000):
    """Reduce ``dataset`` in place to a binary task over two original classes.

    Keeps only the samples whose label is ``first_class`` or ``second_class``
    (in their original order), relabels them to ``first_label`` and
    ``second_label`` respectively, and truncates to the first ``limit``
    samples. ``dataset.data`` and ``dataset.targets`` are overwritten;
    ``targets`` ends up as a plain Python list.
    """
    labels = np.array(dataset.targets)
    keep = (labels == first_class) | (labels == second_class)
    kept_labels = labels[keep]
    binary_labels = np.where(kept_labels == first_class, first_label, second_label)
    dataset.data = dataset.data[keep][:limit]
    dataset.targets = binary_labels[:limit].tolist()


# Target task (classes 0 vs 1): 25 draws per class, each taken with
# replacement from only the first 3 images of that class.
# NOTE(review): the pool of 3 looks like a deliberate few-shot setup -- confirm.
a=np.array(groupa.targets)
data0=groupa.data[a==0]
data1=groupa.data[a==1]
rdchoice1=np.random.choice(3, 25)
rdchoice2=np.random.choice(3, 25)
groupa.data=np.concatenate([data0[rdchoice1],data1[rdchoice2]],axis=0)
# Class-0 draws come first, then class-1 draws, matching the data order above.
groupa.targets=[0]*len(rdchoice1)+[1]*len(rdchoice2)

# Source tasks: one binary problem per remaining class pair.
_select_class_pair(groupb, 2, 3)
_select_class_pair(groupc, 4, 5)
# NOTE(review): for this pair the labels are swapped on purpose to mirror the
# original code (class 6 -> 1, class 7 -> 0) -- confirm this is intended.
_select_class_pair(groupd, 6, 7, first_label=1, second_label=0)
_select_class_pair(groupe, 8, 9)

# Test task: every test image of classes 0 and 1, labels unchanged.
test_labels=np.array(test_set.targets)
_test_keep=test_labels<2
test_set.data=test_set.data[_test_keep]
test_labels=test_labels[_test_keep]
test_set.targets=test_labels.tolist()

#----------------
# Loader over the target-task dataset (batches of 50) and over the filtered
# test split (batches of 100); both reshuffle on every pass.
target = torch.utils.data.DataLoader(
    dataset=groupa,
    batch_size=50,
    shuffle=True,
)
testset = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=100,
    shuffle=True,
)


#------------------------------------
# One batch is drawn from the target loader here and re-used, unchanged, in
# every training step and every evaluation pass below.
targetiter = iter(target)
samplest, labelst = next(targetiter)

# Float one-hot encoding (2 columns) of the target labels.
labels_one_hot_t = torch.zeros(labelst.size(0), 2)
labels_one_hot_t.scatter_(1, labelst.unsqueeze(1), 1)

# Optimiser step size and number of passes over the source loaders.
lr = 1e-4
epoch = 10

# Fixed task weights used in the loss: index 0 multiplies the target term,
# indices 1-4 the source terms b, c, d, e.
#alpha0=torch.rand(5)
#alpha0=alpha0/alpha0.sum()
alpha0 = np.array([0.15, 0.20, 0.01, 0.02, 0.62])

# Feature and label networks share a single Adam optimiser.
model_f = Net_f().to(device)
model_g = Net_g().to(device)
_trainable = [*model_f.parameters(), *model_g.parameters()]
optimizer_fg = torch.optim.Adam(_trainable, lr=lr)

# Per-epoch mean loss, and accuracy history (seeded with 0 so max() works).
losslist = []
acclist = [0]
#----the first epoch
def _one_hot(labels, num_class=2):
    """Return a float one-hot matrix with one row per entry of ``labels``."""
    return torch.zeros(len(labels), num_class).scatter_(1, labels.view(-1,1), 1)


def _test_accuracy():
    """Return the accuracy of the current models over the ``testset`` loader.

    Both models are switched to eval mode. Each test feature is centred with
    the mean feature of the fixed target batch, the two label embeddings are
    centred with their own mean, and the predicted class is the arg-max of
    the resulting inner products.
    """
    model_f.eval()
    model_g.eval()
    correct=0
    total=0
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        f_target=model_f(samplest.to(device)).cpu().numpy()
        f_mean=np.sum(f_target,axis=0)/f_target.shape[0]
        g_class=model_g(torch.eye(2).to(device)).cpu().numpy()
        g_centered=g_class-np.sum(g_class,axis=0)/g_class.shape[0]
        for samples, labels in testset:
            f_centered=model_f(samples.to(device)).cpu().numpy()-f_mean
            scores=np.dot(f_centered,g_centered.T)
            correct += (np.argmax(scores, axis = 1) == labels.numpy()).sum()
            total += len(samples)
    return float(correct) / total


# The source loaders are loop-invariant; with shuffle=True each new pass over
# a loader draws a fresh permutation, so they are built once and re-iterated.
sourceb=torch.utils.data.DataLoader(groupb, batch_size=25, shuffle=True)
sourcec=torch.utils.data.DataLoader(groupc, batch_size=25, shuffle=True)
sourced=torch.utils.data.DataLoader(groupd, batch_size=25, shuffle=True)
sourcee=torch.utils.data.DataLoader(groupe, batch_size=25, shuffle=True)

# Best-checkpoint bookkeeping. Defined up front so the save step after the
# loop never hits an undefined name.
finalacc=None
paraf=None
parag=None

for i in range(epoch):
    losscc=[]
    for (samplesb,labelsb),(samplesc,labelsc),(samplesd,labelsd),(samplese,labelse) in zip(
            sourceb,sourcec,sourced,sourcee):
        model_f.train()
        model_g.train()
        optimizer_fg.zero_grad()

        # The same fixed target batch is paired with a fresh batch from every
        # source task at each step.
        ft=model_f(samplest.to(device))
        gt=model_g(labels_one_hot_t.to(device))
        fb=model_f(samplesb.to(device))
        gb=model_g(_one_hot(labelsb).to(device))
        fc=model_f(samplesc.to(device))
        gc=model_g(_one_hot(labelsc).to(device))
        fd=model_f(samplesd.to(device))
        gd=model_g(_one_hot(labelsd).to(device))
        fe=model_f(samplese.to(device))
        ge=model_g(_one_hot(labelse).to(device))

        # alpha-weighted -2*corr over all five tasks ...
        loss=alpha0[0].item()*(-2)*corr(ft,gt)
        loss+=alpha0[1].item()*(-2)*corr(fb,gb)
        loss+=alpha0[2].item()*(-2)*corr(fc,gc)
        loss+=alpha0[3].item()*(-2)*corr(fd,gd)
        loss+=alpha0[4].item()*(-2)*corr(fe,ge)
        # ... plus the mean-product and cov_trace terms, which are computed
        # on the target batch only.
        loss+=2*((torch.sum(ft,0)/ft.size()[0])*(torch.sum(gt,0)/gt.size()[0])).sum()
        loss+=cov_trace(ft,gt)
        losscc.append(loss.item())
        loss.backward()
        optimizer_fg.step()

        # NOTE(review): the whole test loader is evaluated after every
        # optimiser step and also drives checkpoint selection; kept as in the
        # original, but confirm both are intended.
        acc=_test_accuracy()
        print(acc)
        if acc > 0.7 and acc > max(acclist):
            print('changepara')
            finalacc=acc
            # state_dict() returns references to the live parameter tensors,
            # which later optimiser steps update in place; deep-copy so the
            # snapshot really is the best checkpoint, not the final weights.
            paraf=copy.deepcopy(model_f.state_dict())
            parag=copy.deepcopy(model_g.state_dict())
        acclist.append(acc)

    mean_loss=sum(losscc)/len(losscc)
    losslist.append(mean_loss)
    print(mean_loss)

if paraf is None:
    # No evaluation beat the 0.7 threshold: fall back to the final weights so
    # the run still produces a checkpoint instead of failing at save time.
    print('no checkpoint exceeded the accuracy threshold; keeping final weights')
    finalacc=acclist[-1]
    paraf=copy.deepcopy(model_f.state_dict())
    parag=copy.deepcopy(model_g.state_dict())
#-------------------renewalpha
    

#-----report results and save the selected parameters------------
# Report the accuracy of the selected checkpoint and the task weights used.
print(finalacc)
print(alpha0)
# torch.save does not create missing parent directories; make sure the
# output folder exists so a finished run is not lost at the very last step.
os.makedirs('./../mpara', exist_ok=True)
torch.save(paraf, './../mpara/cifar10f_alpha_6.pth')
torch.save(parag, './../mpara/cifar10g_alpha_6.pth')
# Total wall-clock time since time_start was taken.
time_end=time.time()
print(time_end-time_start)

# The triple-quoted text below is a bare string literal: it is never assigned
# or used, so it has no effect at run time. It holds disabled code that
# accumulates a quantity `former` from per-group trace terms (weighted by
# alpha**2 / group size) and pairwise source-difference terms. The names it
# uses (alphat, alphab, phitphi, phitcovtphi, phipsib, ...) are not defined
# anywhere in this file.
# NOTE(review): presumably a draft of the objective for the "renewalpha" step
# mentioned above -- confirm before reviving or deleting it.
'''
former = alphat**2/len(groupa)*np.trace((phitcovtphi @ np.linalg.inv(phitphi)))
former += alphab**2/len(groupb)*np.trace((phitcovbphi @ np.linalg.inv(phitphi)))
former += alphac**2/len(groupc)*np.trace((phitcovcphi @ np.linalg.inv(phitphi)))
former += alphad**2/len(groupd)*np.trace((phitcovdphi @ np.linalg.inv(phitphi)))
former += alphae**2/len(groupe)*np.trace((phitcovephi @ np.linalg.inv(phitphi)))
former += 2*alphab*alphac* (((phipsib-phipsic).T)@ np.linalg.inv(phitphi) @ (phipsib-phipsic))
former += 2*alphab*alphad* (((phipsib-phipsid).T)@ np.linalg.inv(phitphi) @ (phipsib-phipsid))
former += 2*alphab*alphae* (((phipsib-phipsie).T)@ np.linalg.inv(phitphi) @ (phipsib-phipsie))
former += 2*alphac*alphad* (((phipsic-phipsid).T)@ np.linalg.inv(phitphi) @ (phipsic-phipsid))
former += 2*alphac*alphae* (((phipsic-phipsie).T)@ np.linalg.inv(phitphi) @ (phipsic-phipsie))
former += 2*alphad*alphae* (((phipsid-phipsie).T)@ np.linalg.inv(phitphi) @ (phipsid-phipsie))
'''